import torch
import re
from torch.utils.data import Dataset

## dataset
class IntentDataset(Dataset):
    """Map-style torch Dataset over a sequence of pre-computed feature objects.

    Indexing returns a dict of tensors, each cast to int64 (``torch.long``),
    keyed by feature name.
    """

    def __init__(self, features):
        # features: indexable sequence whose items expose the six tensor
        # attributes read in __getitem__ (e.g. the list built by get_features).
        self.features = features
        # Length is captured once at construction time.
        self.nums = len(features)

    def __len__(self):
        return self.nums

    def __getitem__(self, item):
        feature = self.features[item]
        field_names = (
            "input_ids",
            "attention_mask",
            "token_type_ids",
            "domain_label_ids",
            "intent_label_ids",
            "slot_label_ids",
        )
        return {name: getattr(feature, name).long() for name in field_names}

# 原始数据样本
class InputExample:
    """One raw sample as read from the data file, before tokenization."""

    def __init__(self, set_type, text, domain_label, intent_label, slot_label):
        # Which split the sample belongs to, plus the raw utterance.
        self.set_type, self.text = set_type, text
        # Text-valued labels; slot_label maps slot name -> slot value.
        self.domain_label, self.intent_label, self.slot_label = (
            domain_label, intent_label, slot_label)

# 数据特征
class InputFeature:
    """Tensorised model inputs and label ids for a single example."""

    def __init__(self,
                 input_ids,
                 attention_mask,
                 token_type_ids,
                 domain_label_ids,
                 intent_label_ids,
                 slot_label_ids):
        # Token-level inputs produced by the tokenizer.
        self.input_ids, self.attention_mask, self.token_type_ids = (
            input_ids, attention_mask, token_type_ids)
        # Label targets: domain id, intent id and per-position slot ids.
        self.domain_label_ids, self.intent_label_ids, self.slot_label_ids = (
            domain_label_ids, intent_label_ids, slot_label_ids)

# 读取文件获取数据
def get_examples(path, type):
    """Read a data file and wrap every record in an InputExample.

    The file must contain a list of records, each a mapping with the keys
    'text', 'domain', 'intent' and 'slots'. It is parsed as JSON first. If
    that fails, it is parsed as a Python literal (e.g. single-quoted strings).
    ``eval()`` is deliberately not used because it would execute arbitrary
    code contained in the file.

    Args:
        path: path to the data file, read as UTF-8 (a leading BOM is ignored).
        type: split name stored as ``InputExample.set_type``. The name shadows
            the builtin but is kept for callers passing it by keyword.

    Returns:
        list of InputExample, in file order.

    Raises:
        ValueError / SyntaxError: if the file is neither valid JSON nor a
            valid Python literal.
        KeyError: if a record lacks one of the required keys.
    """
    import ast
    import json

    # 'utf-8-sig' decodes plain UTF-8 identically and also strips a BOM.
    with open(path, 'r', encoding='utf-8-sig') as fp:
        raw = fp.read()
    try:
        data = json.loads(raw)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        data = ast.literal_eval(raw)
    return [
        InputExample(type, d['text'], d['domain'], d['intent'], d['slots'])
        for d in data
    ]

# 样本tokenize：用分词器将一条原始数据样本转换成模型能够接收的数字特征
def convert_example_to_feature(ex_idx, example, tokenizer, config):
    """Convert one InputExample into a fixed-length InputFeature.

    Slot labels are character-level BIO ids. Each slot value is located
    verbatim in the text. Its first character gets 'B-<slot>' and the rest
    get 'I-<slot>'. All other positions get 0.

    Args:
        ex_idx: index of the example (unused; kept for call compatibility).
        example: InputExample whose slot_label maps slot name -> slot value
            string.
        tokenizer: object exposing ``encode_plus`` (presumably a Hugging Face
            tokenizer — confirm).
        config: object providing domainlabel2id, intentlabel2id,
            slotlabel2id (with 'B-<slot>' / 'I-<slot>' keys) and max_len.

    Returns:
        InputFeature holding input_ids, attention_mask, token_type_ids,
        domain_label_ids, intent_label_ids and slot_label_ids tensors.
        slot_label_ids always has exactly config.max_len entries.
    """
    text = example.text
    domain_label_ids = config.domainlabel2id[example.domain_label]
    intent_label_ids = config.intentlabel2id[example.intent_label]

    slot_label_ids = [0] * len(text)
    for slot_name, slot_value in example.slot_label.items():
        if not slot_value:
            # An empty pattern matches at every offset, including len(text),
            # which would index past the end of slot_label_ids.
            continue
        # Slot values are literal strings, not regexes. Escape them so
        # characters such as '+', '(' or '.' are matched verbatim.
        for match in re.finditer(re.escape(slot_value), text):
            start, end = match.span()
            slot_label_ids[start] = config.slotlabel2id['B-' + slot_name]
            for pos in range(start + 1, end):
                slot_label_ids[pos] = config.slotlabel2id['I-' + slot_name]

    # Keep at most max_len - 2 character labels, matching the tokenizer's
    # truncation. Add 0 for the two special-token positions, then pad with 0
    # up to max_len. Without the slice, long texts produced slot label
    # sequences longer than input_ids.
    slot_label_ids = [0] + slot_label_ids[:config.max_len - 2] + [0]
    slot_label_ids += [0] * (config.max_len - len(slot_label_ids))

    # Each character is passed as its own token so that input positions line
    # up one-to-one with slot_label_ids. This assumes the tokenizer does not
    # re-split these tokens; confirm for the tokenizer in use.
    inputs = tokenizer.encode_plus(
        text=list(text),
        max_length=config.max_len,
        padding="max_length",
        truncation='only_first',
        return_attention_mask=True,
        return_token_type_ids=True,
    )

    input_ids = torch.tensor(inputs['input_ids'], requires_grad=False)
    attention_mask = torch.tensor(inputs['attention_mask'], requires_grad=False)
    token_type_ids = torch.tensor(inputs['token_type_ids'], requires_grad=False)
    domain_label_ids = torch.tensor(domain_label_ids, requires_grad=False)
    intent_label_ids = torch.tensor(intent_label_ids, requires_grad=False)
    slot_label_ids = torch.tensor(slot_label_ids, requires_grad=False)

    return InputFeature(input_ids, attention_mask, token_type_ids,
                        domain_label_ids, intent_label_ids, slot_label_ids)

# 把原始数据examples全部转换成模型能够接收的数字特征features
def get_features(examples, tokenizer, config):
    """Tokenize every raw example into an InputFeature, preserving order.

    Each example is passed to convert_example_to_feature together with its
    position in ``examples``.
    """
    return [
        convert_example_to_feature(idx, example, tokenizer, config)
        for idx, example in enumerate(examples)
    ]

# 获取模型参数量
def get_total_params(model):
    """Return the number of scalar elements summed over ``model.parameters()``.

    Trainable and frozen parameters are counted alike.
    """
    count = 0
    for param in model.parameters():
        count += param.numel()
    return count

# intent评估函数
def get_metrics4intent(golds, preds, average='micro'):
    """Compute accuracy and averaged precision/recall/F1 for label predictions.

    Args:
        golds: gold labels.
        preds: predicted labels, aligned with ``golds``.
        average: averaging strategy forwarded to scikit-learn's
            precision/recall/F1 scorers. The default 'micro' matches the
            original behavior. For single-label multiclass data,
            micro-averaged P/R/F1 all equal accuracy, so pass 'macro' or
            'weighted' for class-sensitive scores.

    Returns:
        tuple (acc, precision, recall, f1).
    """
    # Local import keeps scikit-learn an optional dependency of this module.
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    acc = accuracy_score(golds, preds)
    precision = precision_score(golds, preds, average=average)
    recall = recall_score(golds, preds, average=average)
    f1 = f1_score(golds, preds, average=average)
    return acc, precision, recall, f1

# slot评估函数
def get_metrics4slot(golds, preds):
    """Compute seqeval accuracy, precision, recall and F1 for tag sequences.

    ``golds`` and ``preds`` are passed straight to seqeval's scorers
    (presumably lists of per-sentence tag-string lists — confirm with caller).

    Returns:
        tuple (acc, precision, recall, f1).
    """
    # Local import keeps seqeval an optional dependency of this module.
    from seqeval.metrics import accuracy_score, precision_score, recall_score, f1_score
    scorers = (accuracy_score, precision_score, recall_score, f1_score)
    return tuple(scorer(golds, preds) for scorer in scorers)